% trainPriorWL: Learn per-class prior statistics and prior weights from a
% training set
%
% trainPriorWL(lst,saveloc,true_l,clabel)
%
%   lst:     Cell (Nx1) containing the paths of pre-computed statistics
%            files.
%   saveloc: Path of the training .mat file passed to train (which
%            should be called first). This function writes its results
%            (cov_t, alphaC and Qc) to a new file named
%            [saveloc '.w2.mat'].
%   true_l:  Nx3 array containing the ground-truth illuminant for each
%            file. If this parameter is omitted, all illuminants will
%            be assumed to be [1 1 1].
%   clabel:  Nx1 array assigning a class number to each image. The
%            class number must be an integer from 1 to CMAX; CMAX
%            different priors will be learnt.
function trainPriorWL(lst,saveloc,true_l,clabel)

  N = length(lst);

  % Load the training file into a struct and take only cov_t from it, so
  % that other variables stored in that file cannot overwrite this
  % function's own arguments (lst, saveloc, true_l, clabel).
  S = load('-mat',saveloc);
  cov_t = [];
  if isfield(S,'cov_t')
    cov_t = S.cov_t;
  end;
  outfile = [saveloc '.w2.mat'];

  % NOTE(review): the header says an omitted true_l defaults to [1 1 1],
  % but no default is implemented here. With identical illuminants the
  % per-class second-moment matrix of 1./true_l would be rank one and the
  % inv() below singular, so such a default would need more thought.

  % Scale every ground-truth illuminant (row) to unit length.
  true_l = true_l ./ repmat(sqrt(sum(true_l.^2,2)),[1 3]);

  numC = max(clabel);

  % Per-class matrix Qc{c} = inv(mean of w'*w) over the images of class c,
  % with w = 1./true_l (row-wise).
  Qc = cell(numC,1);
  for c = 1:numC
    w = 1./true_l(clabel == c,:);
    Qc{c} = inv((w' * w)/size(w,1));
  end;

  % One 3x3 matrix per training file, rescaled by its sample count NK and
  % by the squared norm of the atoell() estimate obtained from it.
  A = zeros(3,3,N);
  for i = 1:N
    fprintf('\r File %d of %d         ',i,N);
    fouts = loadpc(lst{i});

    [Ai,NK] = getA(fouts,cov_t,true_l(i,:));
    ell = atoell(Ai);
    A(:,:,i) = Ai*NK / sum(ell.^2);
  end;
  fprintf('\n');

  alphaC = zeros(numC,1);
  for c = 1:numC
    members = find(clabel == c);
    % Pass this class's own illuminants: slice j of A(:,:,members) belongs
    % to image members(j), not to image j of the full list.
    alphaC(c) = 10^searchLogAlpha(A(:,:,members),true_l(members,:),Qc{c});
  end;

  save('-mat',outfile,'cov_t','alphaC','Qc');

function alpha = searchLogAlpha(Ac,tl,Q)
  % searchLogAlpha: coarse-to-fine 1-D search for the log10 weight alpha
  % that minimises the summed angular error (ferr) between tl(j,:) and
  % atoell(Ac(:,:,j) + Q*10^alpha), over all slices j of Ac.
  %
  %   Ac: 3x3xM matrices of one class
  %   tl: Mx3 ground-truth illuminants, row j matching Ac(:,:,j)
  %   Q:  3x3 matrix added with weight 10^alpha
  %
  % The search starts on [1, 14] and runs 6 rounds. Each round evaluates 4
  % evenly spaced points and re-centres on the best one, using a window of
  % +/-0.8 of the current grid step.
  lft = 1; rght = 14;
  for hier = 1:6
    alphas = linspace(lft,rght,4);
    score = zeros(size(alphas));
    for k = 1:length(alphas)
      for j = 1:size(Ac,3)
        score(k) = score(k) + ferr(tl(j,:),atoell(Ac(:,:,j)+Q*10^alphas(k)));
      end;
    end;
    % min() returns the first minimiser, matching a stable sort's first entry.
    [~,best] = min(score);
    alpha = alphas(best);
    step = alphas(2)-alphas(1);
    lft = alpha - 0.8*step;
    rght = alpha + 0.8*step;
  end;
    
function g = ferr(lo,l)
  % ferr: angle, in degrees, between vectors lo and l.
  %
  %   lo: reference vector (any shape; used as a column)
  %   l:  vector with the same number of elements as lo
  %   g:  angle between them in degrees, in [0, 180]; NaN if either
  %       vector has zero length
  %
  % Both vectors are scaled to unit length. Rounding can push their dot
  % product slightly outside [-1, 1] for (anti)parallel inputs, where
  % acosd would return a complex number, so the product is clamped. NaN is
  % deliberately left untouched so degenerate inputs still yield NaN.
  lo = lo(:);
  l = l(:);
  lo = lo / sqrt(sum(lo.^2));
  l = l / sqrt(sum(l.^2));
  c = sum(l.*lo);
  c(c > 1) = 1;
  c(c < -1) = -1;
  g = acosd(c);

function [A,NK] = getA(fouts,si,m)
  % getA: accumulate a 3x3 matrix from the row-data in fouts.
  %
  %   fouts: cell array; fouts{k} is a matrix whose rows are 3-vectors
  %   si:    3x3xK stack; si(:,:,k) is used with fouts{k}
  %   m:     1x3 per-channel scale; data are divided by m before weighting
  %   A:     4/NK times the sum over k of (d_k'*d_k) .* si(:,:,k), where d_k
  %          is fouts{k} with each row divided by its weight (see below)
  %   NK:    total number of rows over all cells
  %
  % Each row x of fouts{k} gets the weight |y*si(:,:,k)*y'|^(1/4), with
  % y = x*diag(1./m); weights are floored at 1e-24 before dividing.
  % NOTE(review): the factor 4 and the 1/4 power are not explained here.
  scaleM = diag(1./m);
  A = zeros(3,3);
  NK = 0;
  for k = 1:length(fouts)
    rows = fouts{k};
    Sk = si(:,:,k);
    NK = NK + size(rows,1);

    y = rows * scaleM;
    wt = max(abs(sum(y .* (y * Sk),2)).^(0.25), 10^-24);
    rows = bsxfun(@rdivide, rows, wt);

    A = A + (rows' * rows) .* Sk;
  end;
  A = 4*A/NK;
  
function ell = atoell(A,maxIter,tol)
  % atoell: estimate a positive vector ell from a symmetric matrix A by
  % cyclic coordinate descent.
  %
  %   A:       symmetric NxN matrix (3x3 in this file); assumed to have a
  %            positive diagonal - TODO confirm for all callers
  %   maxIter: optional maximum number of full sweeps (default 10)
  %   tol:     optional relative-decrease stopping tolerance (default 1e-4)
  %   ell:     Nx1 column vector
  %
  % Each coordinate step sets ell(i) to the positive root of
  %   ell(i)^2 - s*ell(i) - A(i,i) = 0,   s = sum_{j~=i} A(j,i)/ell(j),
  % which is the exact minimiser over ell(i) of
  %   F(ell) = sum(log(ell)) + 0.5*w'*A*w,   w = 1./ell.
  % The stopping test therefore evaluates this same F.
  if nargin < 2 || isempty(maxIter)
    maxIter = 10;
  end;
  if nargin < 3 || isempty(tol)
    tol = 10^-4;
  end;

  ell = sqrt(diag(A));   % exact answer when A is diagonal
  n = numel(ell);
  C = inf;
  for it = 1:maxIter
    for i = 1:n
      sj = sum(A(:,i) ./ ell(:)) - A(i,i)/ell(i);
      ell(i) = 0.5*(sj + sqrt(sj^2 + 4*A(i,i)));
    end;
    w = 1 ./ ell;
    Cn = sum(log(ell)) + 0.5*w'*A*w;
    % Stop once a sweep lowers F by at most tol relative to |F|. Using
    % abs() keeps the test meaningful when F is negative; the isfinite
    % guard skips the test on the first sweep (C = inf).
    if isfinite(C) && (C - Cn <= tol*abs(C))
      break;
    end;
    C = Cn;
  end;
